Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[core/ FEAT] Add the possibility to push custom tags using PreTrainedModel itself #28405

Merged
merged 29 commits into from
Jan 15, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
dd5cbe3
v1 tags
younesbelkada Jan 9, 2024
f78ed31
remove unneeded conversion
younesbelkada Jan 9, 2024
da00274
v2
younesbelkada Jan 9, 2024
1180585
rm unneeded warning
younesbelkada Jan 9, 2024
3485b10
Merge remote-tracking branch 'upstream/main' into set-custom-tag
younesbelkada Jan 9, 2024
4b82255
add more utility methods
younesbelkada Jan 9, 2024
4c7806e
Update src/transformers/utils/hub.py
younesbelkada Jan 9, 2024
8b89796
Update src/transformers/utils/hub.py
younesbelkada Jan 9, 2024
e73dc7b
Update src/transformers/utils/hub.py
younesbelkada Jan 9, 2024
fbef2de
more enhancements
younesbelkada Jan 9, 2024
0e4daad
oops
younesbelkada Jan 9, 2024
c19e751
merge tags
younesbelkada Jan 9, 2024
eb93371
clean up
younesbelkada Jan 9, 2024
1fe93b3
revert unneeded change
younesbelkada Jan 9, 2024
a24ad9b
Merge remote-tracking branch 'upstream/main' into set-custom-tag
younesbelkada Jan 10, 2024
6cfd6f5
add extensive docs
younesbelkada Jan 10, 2024
40a1d4b
more docs
younesbelkada Jan 10, 2024
dc31941
more kwargs
younesbelkada Jan 10, 2024
acd676b
add test
younesbelkada Jan 10, 2024
db3197d
oops
younesbelkada Jan 10, 2024
f14cf93
fix test
younesbelkada Jan 10, 2024
31117f4
Update src/transformers/modeling_utils.py
younesbelkada Jan 10, 2024
36f2cb7
Update src/transformers/utils/hub.py
younesbelkada Jan 10, 2024
514f13b
Update src/transformers/modeling_utils.py
younesbelkada Jan 10, 2024
b3d5900
Update src/transformers/trainer.py
younesbelkada Jan 15, 2024
22d3412
Update src/transformers/modeling_utils.py
younesbelkada Jan 15, 2024
85584ae
Merge remote-tracking branch 'upstream/main' into set-custom-tag
younesbelkada Jan 15, 2024
1e3fc1e
add more conditions
younesbelkada Jan 15, 2024
59738c6
more logic
younesbelkada Jan 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 76 additions & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
replace_return_docstrings,
strtobool,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.hub import convert_file_size_to_int, create_and_tag_model_card, get_checkpoint_shard_files
from .utils.import_utils import (
ENV_VARS_TRUE_VALUES,
is_sagemaker_mp_enabled,
Expand Down Expand Up @@ -1159,6 +1159,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
config_class = None
base_model_prefix = ""
main_input_name = "input_ids"
model_tags = None

_auto_class = None
_no_split_modules = None
_skip_keys_device_placement = None
Expand Down Expand Up @@ -1239,6 +1241,62 @@ def _backward_compatibility_gradient_checkpointing(self):
# Remove the attribute now that it has been consumed, so it's not saved in the config.
delattr(self.config, "gradient_checkpointing")

def add_model_tags(self, tags: Union[List[str], str]) -> None:
    r"""
    Append each tag in `tags` to `model_tags`, skipping tags that are already present.

    Args:
        tags (`Union[List[str], str]`):
            The desired tags to inject in the model
    """
    if self.model_tags is None:
        self.model_tags = []

    # A single string is treated as a one-element tag list.
    candidates = [tags] if isinstance(tags, str) else tags

    for candidate in candidates:
        if candidate not in self.model_tags:
            self.model_tags.append(candidate)

def set_model_tags(self, tags: Union[List[str], str]) -> None:
    r"""
    Manually force-set the model tags with `tags`, replacing any existing tags.

    Args:
        tags (`Union[List[str], str]`):
            The desired tags to inject in the model
    """
    # Normalize a bare string into a one-element list before assigning.
    self.model_tags = [tags] if isinstance(tags, str) else tags

def reset_model_tags(self) -> None:
    r"""
    Manually reset the model tags with an empty list. No-op when tags were never set.
    """
    if self.model_tags is None:
        return
    self.model_tags = []

def remove_model_tags(self, tags: Union[List[str], str]) -> None:
    r"""
    Manually remove all elements of `tags` from the model tags.

    Args:
        tags (`Union[List[str], str]`):
            The desired tags to remove from the model
    """
    # Nothing to remove if tags were never set.
    if self.model_tags is None:
        return

    to_remove = [tags] if isinstance(tags, str) else tags
    for candidate in to_remove:
        if candidate in self.model_tags:
            self.model_tags.remove(candidate)

@classmethod
def _from_config(cls, config, **kwargs):
"""
Expand Down Expand Up @@ -2425,6 +2483,12 @@ def save_pretrained(
)

if push_to_hub:
# Eventually create an empty model card
model_card = create_and_tag_model_card(repo_id, self.model_tags)

# Update model card if needed:
model_card.save(os.path.join(save_directory, "README.md"))

self._upload_modified_files(
save_directory,
repo_id,
Expand All @@ -2433,6 +2497,17 @@ def save_pretrained(
token=token,
)

@wraps(PushToHubMixin.push_to_hub)
def push_to_hub(self, *args, **kwargs):
    """
    Upload the model to the Hub, merging `self.model_tags` with any tags the caller
    passed via the `tags` kwarg.

    Model tags come first, then caller tags are appended without duplicates. A `tags`
    kwarg given as a single string is accepted and treated as a one-element list.
    The `tags` kwarg is only forwarded when the merged list is non-empty, so
    `tags=None` is never sent to `PushToHubMixin.push_to_hub`.
    """
    # Copy so merging caller tags never mutates `self.model_tags` in place.
    tags = list(self.model_tags) if self.model_tags is not None else []

    # Callers may pass `tags` as a single string; normalize before merging.
    passed_tags = kwargs.get("tags") or []
    if isinstance(passed_tags, str):
        passed_tags = [passed_tags]

    for tag in passed_tags:
        if tag not in tags:
            tags.append(tag)

    if tags:
        kwargs["tags"] = tags
    return super().push_to_hub(*args, **kwargs)
Copy link
Collaborator

@amyeroberts amyeroberts Jan 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current if/else logic means:

  • if self.model_tags is None and tags aren't passed in kwargs['tags'] is set to None. Do we want this or no tags kwarg at all?
    * tags are only used if self.model_tags is not None. - just realised I'm wrong as kwargs are passed otherwise 🙃

Another Q

  • Is it possible kwargs["tags"] is a string here too?
Suggested change
if "tags" not in kwargs:
kwargs["tags"] = self.model_tags
elif "tags" in kwargs and self.model_tags is not None:
for model_tag in self.model_tags:
# merge the tags together
if model_tag not in kwargs["tags"]:
kwargs["tags"].append(model_tag)
return super().push_to_hub(*args, **kwargs)
tags = self.model_tags if self.model_tags is not None else []
for tag in kwargs.get("tags", []):
if tag not in tags:
tags.append(tag)
if tags:
kwargs["tags"] = tags
return super().push_to_hub(*args, **kwargs)

Copy link
Contributor Author

@younesbelkada younesbelkada Jan 15, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think your suggestion sounds great, also

Is it possible kwargs["tags"] is a string here too?

It can be strings, let me adapt a bit the logic after accepting your suggestion

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just took care of the str case in 1e3fc1e


def get_memory_footprint(self, return_buffers=True):
r"""
Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
Expand Down
12 changes: 12 additions & 0 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3684,6 +3684,18 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
if not self.is_world_process_zero():
return

# Add additional tags in case the model already has some tags and the user passes the
# "tags" argument to `push_to_hub`, so that the trainer automatically handles internal tags
# from all models since Trainer does not call `model.push_to_hub`.
if "tags" in kwargs and getattr(self.model, "model_tags", None) is not None:
# If it is a string, convert it to an array
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(kwargs["tags"], str):
kwargs["tags"] = [kwargs["tags"]]

for model_tag in self.model.model_tags:
if model_tag not in kwargs["tags"]:
kwargs["tags"].append(model_tag)

self.create_model_card(model_name=model_name, **kwargs)

# Wait for the current upload to be finished.
Expand Down
32 changes: 32 additions & 0 deletions src/transformers/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from huggingface_hub import (
_CACHED_NO_EXIST,
CommitOperationAdd,
ModelCard,
ModelCardData,
constants,
create_branch,
create_commit,
Expand Down Expand Up @@ -762,6 +764,7 @@ def push_to_hub(
safe_serialization: bool = True,
revision: str = None,
commit_description: str = None,
tags: List[str] = None,
younesbelkada marked this conversation as resolved.
Show resolved Hide resolved
**deprecated_kwargs,
) -> str:
"""
Expand Down Expand Up @@ -795,6 +798,8 @@ def push_to_hub(
Branch to push the uploaded files to.
commit_description (`str`, *optional*):
The description of the commit that will be created
tags (`List[str]`, *optional*):
List of tags to push on the Hub.

Examples:

Expand Down Expand Up @@ -855,6 +860,9 @@ def push_to_hub(
repo_id, private=private, token=token, repo_url=repo_url, organization=organization
)

# Create a new empty model card and eventually tag it
model_card = create_and_tag_model_card(repo_id, tags)

if use_temp_dir is None:
use_temp_dir = not os.path.isdir(working_dir)

Expand All @@ -864,6 +872,9 @@ def push_to_hub(
# Save all files.
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)

# Update model card if needed:
model_card.save(os.path.join(work_dir, "README.md"))

return self._upload_modified_files(
work_dir,
repo_id,
Expand Down Expand Up @@ -1081,6 +1092,27 @@ def extract_info_from_url(url):
return {"repo": cache_repo, "revision": revision, "filename": filename}


def create_and_tag_model_card(repo_id: str, tags: Optional[List[str]] = None):
    """
    Creates a dummy model card and tags it.

    Args:
        repo_id (`str`):
            The repo id where to look for the model card.
        tags (`List[str]`, *optional*):
            The list of tags to add in the model card.

    Returns:
        `ModelCard`: the loaded (or freshly created) model card with `tags` merged
        into its metadata.
    """
    try:
        # Check if the model card is present on the remote repo
        model_card = ModelCard.load(repo_id)
    except EntryNotFoundError:
        # Otherwise create a simple model card from template
        model_description = "This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated."
        # Copy `tags` so later appends never mutate the caller's list.
        card_data = ModelCardData(tags=[] if tags is None else list(tags), library_name="transformers")
        model_card = ModelCard.from_template(card_data, model_description=model_description)

    if tags is not None:
        # A card loaded from the Hub may have no tags at all (`data.tags is None`);
        # normalize to a list before merging to avoid a TypeError on `in`/`append`.
        if model_card.data.tags is None:
            model_card.data.tags = []
        for model_tag in tags:
            if model_tag not in model_card.data.tags:
                model_card.data.tags.append(model_tag)

    return model_card


def clean_files_for(file):
"""
Remove, if they exist, file, file.json and file.lock
Expand Down
Loading