Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core/ FEAT] Add the possibility to push custom tags using PreTrainedModel itself #28405

Merged
merged 29 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
dd5cbe3
v1 tags
younesbelkada Jan 9, 2024
f78ed31
remove unneeded conversion
younesbelkada Jan 9, 2024
da00274
v2
younesbelkada Jan 9, 2024
1180585
rm unneeded warning
younesbelkada Jan 9, 2024
3485b10
Merge remote-tracking branch 'upstream/main' into set-custom-tag
younesbelkada Jan 9, 2024
4b82255
add more utility methods
younesbelkada Jan 9, 2024
4c7806e
Update src/transformers/utils/hub.py
younesbelkada Jan 9, 2024
8b89796
Update src/transformers/utils/hub.py
younesbelkada Jan 9, 2024
e73dc7b
Update src/transformers/utils/hub.py
younesbelkada Jan 9, 2024
fbef2de
more enhancements
younesbelkada Jan 9, 2024
0e4daad
oops
younesbelkada Jan 9, 2024
c19e751
merge tags
younesbelkada Jan 9, 2024
eb93371
clean up
younesbelkada Jan 9, 2024
1fe93b3
revert unneeded change
younesbelkada Jan 9, 2024
a24ad9b
Merge remote-tracking branch 'upstream/main' into set-custom-tag
younesbelkada Jan 10, 2024
6cfd6f5
add extensive docs
younesbelkada Jan 10, 2024
40a1d4b
more docs
younesbelkada Jan 10, 2024
dc31941
more kwargs
younesbelkada Jan 10, 2024
acd676b
add test
younesbelkada Jan 10, 2024
db3197d
oops
younesbelkada Jan 10, 2024
f14cf93
fix test
younesbelkada Jan 10, 2024
31117f4
Update src/transformers/modeling_utils.py
younesbelkada Jan 10, 2024
36f2cb7
Update src/transformers/utils/hub.py
younesbelkada Jan 10, 2024
514f13b
Update src/transformers/modeling_utils.py
younesbelkada Jan 10, 2024
b3d5900
Update src/transformers/trainer.py
younesbelkada Jan 15, 2024
22d3412
Update src/transformers/modeling_utils.py
younesbelkada Jan 15, 2024
85584ae
Merge remote-tracking branch 'upstream/main' into set-custom-tag
younesbelkada Jan 15, 2024
1e3fc1e
add more conditions
younesbelkada Jan 15, 2024
59738c6
more logic
younesbelkada Jan 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
replace_return_docstrings,
strtobool,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
ENV_VARS_TRUE_VALUES,
is_sagemaker_mp_enabled,
Expand Down Expand Up @@ -1172,6 +1172,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
model_tags = None

_auto_class = None
_no_split_modules = None
_skip_keys_device_placement = None
Expand Down Expand Up @@ -1252,6 +1254,38 @@ def _backward_compatibility_gradient_checkpointing(self):
# Remove the attribute now that it has been consumed, so it's not saved in the config.
delattr(self.config, "gradient_checkpointing")

def add_model_tags(self, tags: Union[List[str], str]) -> None:
    r"""
    Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
    not overwrite existing tags in the model.

    Tags already present are kept (in their original order) and duplicates are skipped.

    Args:
        tags (`Union[List[str], str]`):
            The desired tags to inject in the model. A single string is treated as one tag.

    Examples:

    ```python
    from transformers import AutoModel

    model = AutoModel.from_pretrained("bert-base-cased")

    model.add_model_tags(["custom", "custom-bert"])

    # Push the model to your namespace with the name "my-custom-bert".
    model.push_to_hub("my-custom-bert")
    ```
    """
    if isinstance(tags, str):
        tags = [tags]

    # Merge into a fresh list instead of appending in place: `model_tags` is looked up
    # through the class, so if it is a list declared at class level, an in-place append
    # would leak the new tags to the class and to every other instance.
    merged_tags = [] if self.model_tags is None else list(self.model_tags)

    for tag in tags:
        if tag not in merged_tags:
            merged_tags.append(tag)

    # Store the result on the instance only.
    self.model_tags = merged_tags

@classmethod
def _from_config(cls, config, **kwargs):
"""
Expand Down Expand Up @@ -2212,6 +2246,7 @@ def save_pretrained(
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
use_auth_token = kwargs.pop("use_auth_token", None)
ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False)

if use_auth_token is not None:
warnings.warn(
Expand Down Expand Up @@ -2438,6 +2473,14 @@ def save_pretrained(
)

if push_to_hub:
# Eventually create an empty model card
model_card = create_and_tag_model_card(
repo_id, self.model_tags, token=token, ignore_metadata_errors=ignore_metadata_errors
)

# Update model card if needed:
model_card.save(os.path.join(save_directory, "README.md"))

self._upload_modified_files(
save_directory,
repo_id,
Expand All @@ -2446,6 +2489,22 @@ def save_pretrained(
token=token,
)

@wraps(PushToHubMixin.push_to_hub)
def push_to_hub(self, *args, **kwargs):
    # Merge the tags stored on the model (`self.model_tags`) with any `tags` passed by the
    # caller, then forward the merged list to the parent `push_to_hub` implementation.

    # Copy the model's tags so that tags passed for this call only are not permanently
    # appended to `self.model_tags`.
    tags = list(self.model_tags) if self.model_tags is not None else []

    # `tags` may be absent or passed explicitly as `None`; both mean "no extra tags".
    tags_kwargs = kwargs.get("tags") or []
    if isinstance(tags_kwargs, str):
        tags_kwargs = [tags_kwargs]

    for tag in tags_kwargs:
        if tag not in tags:
            tags.append(tag)

    # Only override the keyword when there is something to push.
    if tags:
        kwargs["tags"] = tags
    return super().push_to_hub(*args, **kwargs)

def get_memory_footprint(self, return_buffers=True):
r"""
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
Expand Down
21 changes: 21 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3581,6 +3581,15 @@ def create_model_card(
library_name = ModelCard.load(model_card_filepath).data.get("library_name")
is_peft_library = library_name == "peft"

# Append existing tags in `tags`
existing_tags = ModelCard.load(model_card_filepath).data.tags
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This handles a corner case I just found for Trainers: if you push to a repo that already has some tags, the existing tags are overwritten by the new ones. This block simply circumvents that: instead of overwriting, it appends the existing tags to the new tags.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let me know if this is ok

ModelCard.load(model_card_filepath).data.tags

can give None if the model card is empty

from huggingface_hub import ModelCard
ModelCard.load("ybelkada/test-empty-model-case").data.tags
>>> None

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks OK to me - my only comment is that existing_tags might be None, so you'll need to convert it to a list before adding tags if ModelCard.load(model_card_filepath).data.tags doesn't have anything set

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pushed a simple workaround in 59738c6. I think that way we don't need to convert it to a list, as we care about that corner case only if tags are explicitly passed and there are already some existing tags

if tags is not None:
if isinstance(tags, str):
tags = [tags]
for tag in existing_tags:
if tag not in tags:
tags.append(tag)

training_summary = TrainingSummary.from_trainer(
self,
language=language,
Expand Down Expand Up @@ -3699,6 +3708,18 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
if not self.is_world_process_zero():
return

# Add additional tags in the case the model has already some tags and users pass
# "tags" argument to `push_to_hub` so that trainer automatically handles internal tags
# from all models since Trainer does not call `model.push_to_hub`.
if "tags" in kwargs and getattr(self.model, "model_tags", None) is not None:
# If it is a string, convert it to a list
if isinstance(kwargs["tags"], str):
kwargs["tags"] = [kwargs["tags"]]

for model_tag in self.model.model_tags:
if model_tag not in kwargs["tags"]:
kwargs["tags"].append(model_tag)

self.create_model_card(model_name=model_name, **kwargs)

# Wait for the current upload to be finished.
Expand Down
51 changes: 51 additions & 0 deletions src/transformers/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from huggingface_hub import (
_CACHED_NO_EXIST,
CommitOperationAdd,
ModelCard,
ModelCardData,
constants,
create_branch,
create_commit,
Expand Down Expand Up @@ -762,6 +764,7 @@ def push_to_hub(
safe_serialization: bool = True,
revision: str = None,
commit_description: str = None,
tags: Optional[List[str]] = None,
**deprecated_kwargs,
) -> str:
"""
Expand Down Expand Up @@ -795,6 +798,8 @@ def push_to_hub(
Branch to push the uploaded files to.
commit_description (`str`, *optional*):
The description of the commit that will be created
tags (`List[str]`, *optional*):
List of tags to push on the Hub.

Examples:

Expand All @@ -811,6 +816,7 @@ def push_to_hub(
```
"""
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
Expand Down Expand Up @@ -855,6 +861,11 @@ def push_to_hub(
repo_id, private=private, token=token, repo_url=repo_url, organization=organization
)

# Create a new empty model card and eventually tag it
model_card = create_and_tag_model_card(
repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors
)

if use_temp_dir is None:
use_temp_dir = not os.path.isdir(working_dir)

Expand All @@ -864,6 +875,9 @@ def push_to_hub(
# Save all files.
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)

# Update model card if needed:
model_card.save(os.path.join(work_dir, "README.md"))

return self._upload_modified_files(
work_dir,
repo_id,
Expand Down Expand Up @@ -1081,6 +1095,43 @@ def extract_info_from_url(url):
return {"repo": cache_repo, "revision": revision, "filename": filename}


def create_and_tag_model_card(
repo_id: str,
tags: Optional[List[str]] = None,
token: Optional[str] = None,
ignore_metadata_errors: bool = False,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Wauplin @amyeroberts is it fine to assume that repo_type will always be set to "model" in ModelCard.load() (judging from that method's signature)?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes definitely, repo_type="model" is implicit when using ModelCard.
(the parameter exists because ModelCard inherits from the generic RepoCard. Not a perfect design-choice but it's fine :) )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok perfect! let's keep it like that then

):
"""
Creates or loads an existing model card and tags it.

Args:
repo_id (`str`):
The repo_id where to look for the model card.
tags (`List[str]`, *optional*):
The list of tags to add in the model card
token (`str`, *optional*):
Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token.
ignore_metadata_errors (`bool`, *optional*, defaults to `False`):
If True, errors while parsing the metadata section will be ignored. Some information might be lost during
the process. Use it at your own risk.
Comment on lines +1112 to +1116
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied this from ModelCard.load_card docstring

"""
try:
# Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id, token=token, ignore_metadata_errors=ignore_metadata_errors)
except EntryNotFoundError:
# Otherwise create a simple model card from template
model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated."
card_data = ModelCardData(tags=[] if tags is None else tags, library_name="transformers")
model_card = ModelCard.from_template(card_data, model_description=model_description)

if tags is not None:
for model_tag in tags:
if model_tag not in model_card.data.tags:
model_card.data.tags.append(model_tag)

return model_card


def clean_files_for(file):
"""
Remove, if they exist, file, file.json and file.lock
Expand Down
27 changes: 27 additions & 0 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,11 @@ def tearDownClass(cls):
except HTTPError:
pass

try:
delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags")
except HTTPError:
pass

@unittest.skip("This test is flaky")
def test_push_to_hub(self):
config = BertConfig(
Expand Down Expand Up @@ -1522,6 +1527,28 @@ def test_push_to_hub_dynamic_model(self):
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "CustomModel")

def test_push_to_hub_with_tags(self):
from huggingface_hub import ModelCard

new_tags = ["tag-1", "tag-2"]

CustomConfig.register_for_auto_class()
CustomModel.register_for_auto_class()

config = CustomConfig(hidden_size=32)
model = CustomModel(config)

self.assertTrue(model.model_tags is None)

model.add_model_tags(new_tags)

self.assertTrue(model.model_tags == new_tags)

model.push_to_hub("test-dynamic-model-with-tags", token=self._token)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it problematic to have a fixed name for the repo_id? This could lead to conflicts if several CIs are run in parallel. Using a unique id (uuid) would avoid that.
(not an expert of transformers's CI though. It's just that I've experienced similar problem in huggingface_hub's CI in the past)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm good point, we do the same things for all tests above so perhaps it is fine, I will let @ydshieh comment on this if he has more insights than me


loaded_model_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags")
self.assertEqual(loaded_model_card.data.tags, new_tags)


@require_torch
class AttentionMaskTester(unittest.TestCase):
Expand Down
Loading