Automatic safetensors conversion when lacking these files #29390

Merged 5 commits on Mar 5, 2024
36 changes: 33 additions & 3 deletions src/transformers/modeling_utils.py
@@ -29,6 +29,7 @@
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial, wraps
from threading import Thread
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from zipfile import is_zipfile

@@ -3207,9 +3208,38 @@ def from_pretrained(
)
if resolved_archive_file is not None:
is_sharded = True
if resolved_archive_file is None:
# Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
# message.

if resolved_archive_file is not None:
# If the PyTorch file was found, check if there is a safetensors file on the repository
# If there is no safetensors file on the repositories, start an auto conversion
safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
has_file_kwargs = {
"revision": revision,
"proxies": proxies,
"token": token,
}
cached_file_kwargs = {
"cache_dir": cache_dir,
"force_download": force_download,
"resume_download": resume_download,
"local_files_only": local_files_only,
"user_agent": user_agent,
"subfolder": subfolder,
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_commit_hash": commit_hash,
**has_file_kwargs,
}
if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
cls._auto_conversion = Thread(
target=auto_conversion,
args=(pretrained_model_name_or_path,),
kwargs=cached_file_kwargs,
)
cls._auto_conversion.start()
Member Author:
@Wauplin curious if you have a better idea in mind to have access to the thread started here; I don't need to join it during runtime, I'm only attributing it to the class here so that I can access it within the test files (but not super keen on modifying internals just for the tests to be simpler ...)

Contributor:
@LysandreJik I'm not shocked by having a cls._auto_conversion attribute TBH. Though a solution to get rid of it is to give a name to the thread. Something like that:

Thread(
    target=auto_conversion,
    args=(pretrained_model_name_or_path,),
    kwargs=cached_file_kwargs,
    name="Thread-autoconversion-{<unique id here>}",
).start()

and then in the tests:

for thread in threading.enumerate():
    print(thread.name)
# ...
# Thread-autoconversion-0

Thread names don't have to be unique BTW (they have a thread id anyway). But I think it's best to at least assign a unique number to the name.

Contributor:
But it's quite hacky IMO. In a simple case it should work fine but if you start to have several threads / parallel tests, it might get harder to be 100% sure the thread you've started is indeed the one you retrieve in the test logic.

Member Author:
Yeah, here it's really only for testing, and I don't want to depend on a flaky time.sleep or something, so ensuring that the thread joins first is optimal. The thread name is actually much better IMO; I'll implement that! Thanks a lot!
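The naming scheme discussed above can be sketched as follows. This is an illustrative standalone example, not the transformers implementation: `start_named_thread` and the `_thread_id` counter are hypothetical names, and a plain `threading.Event` stands in for the real conversion work so the lookup via `threading.enumerate()` can be shown deterministically.

```python
import threading
from itertools import count

_thread_id = count()  # monotonically increasing ids keep each thread name unique

def start_named_thread(target, *args, **kwargs):
    """Start a thread with a discoverable 'Thread-autoconversion-<n>' name."""
    thread = threading.Thread(
        target=target,
        args=args,
        kwargs=kwargs,
        name=f"Thread-autoconversion-{next(_thread_id)}",
    )
    thread.start()
    return thread

# Keep the worker alive briefly so it is visible in threading.enumerate(),
# the way a test would look it up by name instead of via a class attribute.
release = threading.Event()
worker = start_named_thread(release.wait)
found = [t.name for t in threading.enumerate() if t.name.startswith("Thread-autoconversion")]
release.set()
worker.join()
```

As noted in the thread, names are not guaranteed unique by Python itself, so the counter only reduces (not eliminates) ambiguity under parallel tests.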

else:
# Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file.
# We try those to give a helpful error message.
has_file_kwargs = {
"revision": revision,
"proxies": proxies,
28 changes: 27 additions & 1 deletion tests/test_modeling_utils.py
@@ -1428,7 +1428,7 @@ def test_safetensors_on_the_fly_wrong_user_opened_pr(self):
bot_opened_pr_title = None

for discussion in discussions:
-            if discussion.author == "SFconvertBot":
+            if discussion.author == "SFconvertbot":
Collaborator:
➕ on @julien-c's comment, have had feedback that this is not explicit enough.

Suggested change:
-            if discussion.author == "SFconvertbot":
+            if discussion.author == "HuggingFaceOfficialSafetensorConverter":

bot is scary for some 😅

Member:

We can't change the account name now, but we will think of a way to make it clearer in the UI that it's an "official bot".

Collaborator:
Sounds good 👍🏻

bot_opened_pr = True
bot_opened_pr_title = discussion.title

@@ -1451,6 +1451,32 @@ def test_safetensors_on_the_fly_specific_revision(self):
with self.assertRaises(EnvironmentError):
BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch")

def test_absence_of_safetensors_triggers_conversion(self):
config = BertConfig(
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
)
initial_model = BertModel(config)

# Push a model on `main`
initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)

initial_model = BertModel.from_pretrained(self.repo_name, token=self.token)
BertModel._auto_conversion.join()

with self.subTest("PR was open with the safetensors account"):
discussions = self.api.get_repo_discussions(self.repo_name)

bot_opened_pr = None
bot_opened_pr_title = None

for discussion in discussions:
if discussion.author == "SFconvertbot":
bot_opened_pr = True
bot_opened_pr_title = discussion.title

self.assertTrue(bot_opened_pr)
self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model")
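The discussion scan in the test above can be exercised in isolation. In this sketch, the `Discussion` dataclass is a minimal hypothetical stand-in for the huggingface_hub object, exposing only the two fields the loop reads, and `find_bot_pr` is an illustrative extraction of the loop, not a transformers helper:

```python
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple

@dataclass
class Discussion:
    # Minimal stand-in for a Hub discussion: only the fields the test reads.
    author: str
    title: str

def find_bot_pr(discussions: Iterable[Discussion]) -> Tuple[bool, Optional[str]]:
    """Return whether SFconvertbot opened a discussion, and the title of the last match."""
    bot_opened_pr = False
    bot_opened_pr_title = None
    for discussion in discussions:
        if discussion.author == "SFconvertbot":
            bot_opened_pr = True
            bot_opened_pr_title = discussion.title
    return bot_opened_pr, bot_opened_pr_title

opened, title = find_bot_pr([
    Discussion(author="someone-else", title="Fix typo"),
    Discussion(author="SFconvertbot", title="Adding `safetensors` variant of this model"),
])
```

In the real test the iterable comes from `self.api.get_repo_discussions(self.repo_name)` after the conversion thread has joined, so the bot's PR is guaranteed to exist by the time the scan runs.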


@require_torch
@is_staging_test