[core / FEAT] Add the possibility to push custom tags using `PreTrainedModel` itself (#28405)

* v1 tags

* remove unneeded conversion

* v2

* rm unneeded warning

* add more utility methods

* Update src/transformers/utils/hub.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/utils/hub.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Update src/transformers/utils/hub.py

Co-authored-by: Lucain <lucainp@gmail.com>

* more enhancements

* oops

* merge tags

* clean up

* revert unneeded change

* add extensive docs

* more docs

* more kwargs

* add test

* oops

* fix test

* Update src/transformers/modeling_utils.py

Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>

* Update src/transformers/utils/hub.py

Co-authored-by: Lucain <lucainp@gmail.com>

* Update src/transformers/modeling_utils.py

* Update src/transformers/trainer.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add more conditions

* more logic

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: Lucain <lucainp@gmail.com>
Co-authored-by: Omar Sanseviero <osanseviero@gmail.com>
4 people committed Jan 15, 2024
1 parent 64bdbd8 commit 1b9a2e4
Showing 4 changed files with 159 additions and 1 deletion.
61 changes: 60 additions & 1 deletion src/transformers/modeling_utils.py
@@ -89,7 +89,7 @@
replace_return_docstrings,
strtobool,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
ENV_VARS_TRUE_VALUES,
is_sagemaker_mp_enabled,
@@ -1172,6 +1172,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
model_tags = None

_auto_class = None
_no_split_modules = None
_skip_keys_device_placement = None
@@ -1252,6 +1254,38 @@ def _backward_compatibility_gradient_checkpointing(self):
# Remove the attribute now that it has been consumed, so it's not saved in the config.
delattr(self.config, "gradient_checkpointing")

def add_model_tags(self, tags: Union[List[str], str]) -> None:
r"""
Add custom tags to the model that gets pushed to the Hugging Face Hub. Will
not overwrite existing tags in the model.

Args:
tags (`Union[List[str], str]`):
The desired tags to inject into the model.

Examples:

```python
from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-cased")
model.add_model_tags(["custom", "custom-bert"])

# Push the model to your namespace with the name "my-custom-bert".
model.push_to_hub("my-custom-bert")
```
"""
if isinstance(tags, str):
tags = [tags]

if self.model_tags is None:
self.model_tags = []

for tag in tags:
if tag not in self.model_tags:
self.model_tags.append(tag)

@classmethod
def _from_config(cls, config, **kwargs):
"""
@@ -2212,6 +2246,7 @@ def save_pretrained(
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
"""
use_auth_token = kwargs.pop("use_auth_token", None)
ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False)

if use_auth_token is not None:
warnings.warn(
@@ -2438,6 +2473,14 @@ def save_pretrained(
)

if push_to_hub:
# Create an empty model card if none exists yet, and add the model tags
model_card = create_and_tag_model_card(
repo_id, self.model_tags, token=token, ignore_metadata_errors=ignore_metadata_errors
)

# Update model card if needed:
model_card.save(os.path.join(save_directory, "README.md"))

self._upload_modified_files(
save_directory,
repo_id,
@@ -2446,6 +2489,22 @@
token=token,
)

@wraps(PushToHubMixin.push_to_hub)
def push_to_hub(self, *args, **kwargs):
tags = self.model_tags if self.model_tags is not None else []

tags_kwargs = kwargs.get("tags", [])
if isinstance(tags_kwargs, str):
tags_kwargs = [tags_kwargs]

for tag in tags_kwargs:
if tag not in tags:
tags.append(tag)

if tags:
kwargs["tags"] = tags
return super().push_to_hub(*args, **kwargs)

def get_memory_footprint(self, return_buffers=True):
r"""
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
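Taken together, the new `model_tags` attribute, `add_model_tags`, and the wrapped `push_to_hub` mean that tags stored on the model instance and tags passed at push time are merged before the model card is written. A minimal sketch of the intended usage (the repo name and the extra tag are placeholders):

```python
from transformers import AutoModel

# `model_tags` starts out as None on a freshly loaded model.
model = AutoModel.from_pretrained("bert-base-cased")

# Attach custom tags; duplicates are skipped and a bare string is accepted too.
model.add_model_tags(["custom", "custom-bert"])
model.add_model_tags("custom")
print(model.model_tags)  # ['custom', 'custom-bert']

# Tags passed here are merged with `model.model_tags` before the model card
# is created and uploaded; "my-custom-bert" is a placeholder repo name.
model.push_to_hub("my-custom-bert", tags=["extra-tag"])
```

With this flow, the pushed README ends up tagged with `custom`, `custom-bert`, and `extra-tag`.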
21 changes: 21 additions & 0 deletions src/transformers/trainer.py
@@ -3581,6 +3581,15 @@ def create_model_card(
library_name = ModelCard.load(model_card_filepath).data.get("library_name")
is_peft_library = library_name == "peft"

# Append any tags from the existing model card to `tags`
existing_tags = ModelCard.load(model_card_filepath).data.tags
if tags is not None and existing_tags is not None:
if isinstance(tags, str):
tags = [tags]
for tag in existing_tags:
if tag not in tags:
tags.append(tag)

training_summary = TrainingSummary.from_trainer(
self,
language=language,
@@ -3699,6 +3708,18 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking
if not self.is_world_process_zero():
return

# If the model already carries tags and the user passes a "tags" argument to
# `push_to_hub`, merge the model's internal tags into it, since Trainer pushes the
# files itself and never calls `model.push_to_hub`.
if "tags" in kwargs and getattr(self.model, "model_tags", None) is not None:
# If it is a string, convert it to a list
if isinstance(kwargs["tags"], str):
kwargs["tags"] = [kwargs["tags"]]

for model_tag in self.model.model_tags:
if model_tag not in kwargs["tags"]:
kwargs["tags"].append(model_tag)

self.create_model_card(model_name=model_name, **kwargs)

# Wait for the current upload to be finished.
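For `Trainer`, the same merging happens one level up: tags stored on the model are folded into the `tags` keyword before `create_model_card` runs, since `Trainer` uploads the files itself rather than calling `model.push_to_hub`. A hedged sketch of how this is expected to be used (the model choice, output directory, and tag names are illustrative, and pushing requires a Hub login):

```python
from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
model.add_model_tags(["custom"])  # tags carried on the model instance

args = TrainingArguments(output_dir="my-custom-bert")  # placeholder settings
trainer = Trainer(model=model, args=args)

# "from-trainer" is passed explicitly and gets merged with model.model_tags,
# so the generated model card carries both "custom" and "from-trainer".
trainer.push_to_hub(commit_message="End of training", tags="from-trainer")
```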
51 changes: 51 additions & 0 deletions src/transformers/utils/hub.py
@@ -33,6 +33,8 @@
from huggingface_hub import (
_CACHED_NO_EXIST,
CommitOperationAdd,
ModelCard,
ModelCardData,
constants,
create_branch,
create_commit,
@@ -762,6 +764,7 @@ def push_to_hub(
safe_serialization: bool = True,
revision: str = None,
commit_description: str = None,
tags: Optional[List[str]] = None,
**deprecated_kwargs,
) -> str:
"""
@@ -795,6 +798,8 @@
Branch to push the uploaded files to.
commit_description (`str`, *optional*):
The description of the commit that will be created
tags (`List[str]`, *optional*):
List of tags to push to the Hub.
Examples:
@@ -811,6 +816,7 @@
```
"""
use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False)
if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
@@ -855,6 +861,11 @@
repo_id, private=private, token=token, repo_url=repo_url, organization=organization
)

# Create a new empty model card (or load the existing one) and tag it if needed
model_card = create_and_tag_model_card(
repo_id, tags, token=token, ignore_metadata_errors=ignore_metadata_errors
)

if use_temp_dir is None:
use_temp_dir = not os.path.isdir(working_dir)

@@ -864,6 +875,9 @@
# Save all files.
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)

# Update model card if needed:
model_card.save(os.path.join(work_dir, "README.md"))

return self._upload_modified_files(
work_dir,
repo_id,
@@ -1081,6 +1095,43 @@ def extract_info_from_url(url):
return {"repo": cache_repo, "revision": revision, "filename": filename}


def create_and_tag_model_card(
repo_id: str,
tags: Optional[List[str]] = None,
token: Optional[str] = None,
ignore_metadata_errors: bool = False,
):
"""
Creates a model card (or loads the existing one from the Hub) and tags it.

Args:
repo_id (`str`):
The repo_id of the repository in which to look for the model card.
tags (`List[str]`, *optional*):
The list of tags to add to the model card.
token (`str`, *optional*):
Authentication token, obtained with the `huggingface_hub.HfApi.login` method. Will default to the stored token.
ignore_metadata_errors (`bool`, *optional*, defaults to `False`):
If `True`, errors while parsing the metadata section will be ignored. Some information might be lost during
the process. Use it at your own risk.
"""
try:
# Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id, token=token, ignore_metadata_errors=ignore_metadata_errors)
except EntryNotFoundError:
# Otherwise create a simple model card from template
model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated."
card_data = ModelCardData(tags=[] if tags is None else tags, library_name="transformers")
model_card = ModelCard.from_template(card_data, model_description=model_description)

if tags is not None:
for model_tag in tags:
if model_tag not in model_card.data.tags:
model_card.data.tags.append(model_tag)

return model_card


def clean_files_for(file):
"""
Remove, if they exist, file, file.json and file.lock
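The new `create_and_tag_model_card` helper can also be exercised on its own: it loads the repo's existing README when there is one, otherwise builds a minimal card from a template with `library_name="transformers"`, then appends any missing tags. A small sketch, assuming "username/my-model" is a placeholder repo id and a stored token plus network access are available:

```python
from transformers.utils.hub import create_and_tag_model_card

# Placeholder repo id; if it has no README yet, a template card is generated.
card = create_and_tag_model_card(
    "username/my-model",
    tags=["custom", "custom-bert"],
    token=None,  # falls back to the locally stored token
    ignore_metadata_errors=False,
)
print(card.data.tags)  # any existing tags plus the ones passed above

# The helper only returns the card; saving/uploading is left to the caller,
# exactly as `save_pretrained` and `push_to_hub` do with README.md above.
card.save("README.md")
```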
27 changes: 27 additions & 0 deletions tests/test_modeling_utils.py
@@ -1435,6 +1435,11 @@ def tearDownClass(cls):
except HTTPError:
pass

try:
delete_repo(token=cls._token, repo_id="test-dynamic-model-with-tags")
except HTTPError:
pass

@unittest.skip("This test is flaky")
def test_push_to_hub(self):
config = BertConfig(
@@ -1522,6 +1527,28 @@ def test_push_to_hub_dynamic_model(self):
new_model = AutoModel.from_config(config, trust_remote_code=True)
self.assertEqual(new_model.__class__.__name__, "CustomModel")

def test_push_to_hub_with_tags(self):
from huggingface_hub import ModelCard

new_tags = ["tag-1", "tag-2"]

CustomConfig.register_for_auto_class()
CustomModel.register_for_auto_class()

config = CustomConfig(hidden_size=32)
model = CustomModel(config)

self.assertTrue(model.model_tags is None)

model.add_model_tags(new_tags)

self.assertTrue(model.model_tags == new_tags)

model.push_to_hub("test-dynamic-model-with-tags", token=self._token)

loaded_model_card = ModelCard.load(f"{USER}/test-dynamic-model-with-tags")
self.assertEqual(loaded_model_card.data.tags, new_tags)


@require_torch
class AttentionMaskTester(unittest.TestCase):
